{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ceea3596-1ff6-4b69-8ed0-a91c68951b04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qid\ttitle\n",
      "0\t非法本司法考试怎么复习？\n",
      "1\t男友性欲太强是一种怎样的体验?\n",
      "2\t我国的蒙古族同胞是如何看待外蒙古国的？\n",
      "3\t把游戏里没有聚能设计的C4炸药放在现实中，真的能炸坦克吗？\n",
      "4\t如何评价郎咸平不但从空姐小三那里要回了房子，而且让小三倒赔900万？\n",
      "5\t深度学习相关的职位面试时一般会问什么？会问一些传统的机器学习算法吗？\n",
      "6\t大家有兴趣来预测及描绘腾讯的最终产品形态吗？\n",
      "7\t有哪些不论你再难过，听后也会笑的笑话？\n",
      "8\t对镜头设计而言，法兰距是短好还是长好？\n"
     ]
    }
   ],
   "source": [
    "!head train.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2d921b67-92a8-4d41-93af-1ab112b15242",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qid,topic\n",
      "0,0\n",
      "1,1\n",
      "2,2\n",
      "3,3\n",
      "4,4\n",
      "5,5\n",
      "6,6\n",
      "7,7\n",
      "8,8\n"
     ]
    }
   ],
   "source": [
    "!head sample_submit.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "79a4f648-7449-47d6-b695-e10742249b74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>qid</th>\n",
       "      <th>title</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>非法本司法考试怎么复习？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>男友性欲太强是一种怎样的体验?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>我国的蒙古族同胞是如何看待外蒙古国的？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>把游戏里没有聚能设计的C4炸药放在现实中，真的能炸坦克吗？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>如何评价郎咸平不但从空姐小三那里要回了房子，而且让小三倒赔900万？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75115</th>\n",
       "      <td>75115</td>\n",
       "      <td>如何在嘈杂的都市中像村上春树笔下的人物一样生活？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75116</th>\n",
       "      <td>75116</td>\n",
       "      <td>万圣节聚会怎样打扮会比较好玩？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75117</th>\n",
       "      <td>75117</td>\n",
       "      <td>庞中华的硬笔书法水平如何？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75118</th>\n",
       "      <td>75118</td>\n",
       "      <td>琅琊榜中有哪些细思极恐的细节？</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75119</th>\n",
       "      <td>75119</td>\n",
       "      <td>景深和传感器尺寸到底有没有关系？</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>75120 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         qid                               title\n",
       "0          0                        非法本司法考试怎么复习？\n",
       "1          1                     男友性欲太强是一种怎样的体验?\n",
       "2          2                 我国的蒙古族同胞是如何看待外蒙古国的？\n",
       "3          3       把游戏里没有聚能设计的C4炸药放在现实中，真的能炸坦克吗？\n",
       "4          4  如何评价郎咸平不但从空姐小三那里要回了房子，而且让小三倒赔900万？\n",
       "...      ...                                 ...\n",
       "75115  75115            如何在嘈杂的都市中像村上春树笔下的人物一样生活？\n",
       "75116  75116                     万圣节聚会怎样打扮会比较好玩？\n",
       "75117  75117                       庞中华的硬笔书法水平如何？\n",
       "75118  75118                     琅琊榜中有哪些细思极恐的细节？\n",
       "75119  75119                    景深和传感器尺寸到底有没有关系？\n",
       "\n",
       "[75120 rows x 2 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import jieba\n",
    "from tqdm import tqdm\n",
    "\n",
    "train = pd.read_csv('train.csv', sep='\\t')\n",
    "train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5fa9b397-6b5c-4b7e-9b54-9dc2b54b6003",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "772e77c610654a6faed27123c8767c48",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "from sentence_transformers import SentenceTransformer\n",
    "sentences_1 = [\"样例数据-1\", \"样例数据-2\"]\n",
    "model = SentenceTransformer('BAAI/bge-small-zh-v1.5')\n",
    "embeddings_1 = model.encode(sentences_1, normalize_embeddings=True, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "14b2acfd-01fc-4f09-b811-b2cddc2dcf57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c88be018bd3c4f06b0c2755a0d72bea9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/2348 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "embeddings_1 = model.encode(train['title'].values, normalize_embeddings=True, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc3e85aa-570c-43e7-853a-2e427e9343df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sklearn.cluster import MiniBatchKMeans\n",
    "\n",
    "# sse = []\n",
    "# for n_clusters in tqdm(range(100, 10000, 500)):\n",
    "#     km = MiniBatchKMeans(n_clusters=n_clusters, verbose=0, max_iter=1000)\n",
    "#     km.fit(embeddings_1)\n",
    "#     sse.append(km.inertia_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f9fe66c6-88cf-4c2e-8265-fffef50eaa66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "# plt.plot(\n",
    "#     list(range(100, 10000, 500))[:18],\n",
    "#     sse\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "92431b7c-37af-4319-9ff6-55e84557e11d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-2 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-2 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-2 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-2 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-2 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MiniBatchKMeans(max_iter=2000, n_clusters=3000)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;MiniBatchKMeans<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.cluster.MiniBatchKMeans.html\">?<span>Documentation for MiniBatchKMeans</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>MiniBatchKMeans(max_iter=2000, n_clusters=3000)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "MiniBatchKMeans(max_iter=2000, n_clusters=3000)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "km = MiniBatchKMeans(n_clusters=3000, verbose=0, max_iter=2000)\n",
    "km.fit(embeddings_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5e2e2ebb-c80a-4f1d-b20e-253fc556a327",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31386.64453125"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "km.inertia_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d5a21e8-f500-4c20-82ac-e18128746f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train['topic'] = km.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bb882133-b5a6-4d8a-b5f6-f587b35133da",
   "metadata": {},
   "outputs": [],
   "source": [
    "train[['qid', 'topic']].to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5a5cc83a-302b-4447-ac75-cc0fa4d0dee2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------using 2*GPUs----------\n",
      "[[0.626  0.3477]\n",
      " [0.3499 0.678 ]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import BGEM3FlagModel\n",
    "\n",
    "model = BGEM3FlagModel('/home/lyz/hf-models/bge-m3/') \n",
    "\n",
    "sentences_1 = [\"What is BGE M3?\", \"Defination of BM25\"]\n",
    "sentences_2 = [\"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.\", \n",
    "               \"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document\"]\n",
    "\n",
    "embeddings_1 = model.encode(sentences_1, \n",
    "                            batch_size=12, \n",
    "                            max_length=8192, # If you don't need such a long length, you can set a smaller value to speed up the encoding process.\n",
    "                            )['dense_vecs']\n",
    "embeddings_2 = model.encode(sentences_2)['dense_vecs']\n",
    "similarity = embeddings_1 @ embeddings_2.T\n",
    "print(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "485080da-6945-46d8-ad45-ca570f34967f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Inference Embeddings: 100%|██████████| 376/376 [04:10<00:00,  1.50it/s]\n"
     ]
    }
   ],
   "source": [
    "embeddings_1 = model.encode(list(train['title'].values), batch_size=100)['dense_vecs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d385ee5c-9ee6-4d4f-bd56-246433f767a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import normalize\n",
    "embeddings_1 = normalize(embeddings_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6ba514c7-13a3-4a9f-a8d9-20b9b627e2f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "km = MiniBatchKMeans(n_clusters=2500, verbose=0, max_iter=2000)\n",
    "km.fit(embeddings_1)\n",
    "train['topic'] = km.labels_\n",
    "train[['qid', 'topic']].to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fddbc925-5a52-4048-898d-efe49941fed7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
